{
 "metadata": {
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from model import TTSR\n",
    "from loss.loss import get_loss_dict\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torchvision.utils as utils\n",
    "\n",
    "from model import LTE\n",
    "from trainer import *\n",
    "from dataloader.dataset import ImageDataset\n",
    "\n",
    "from utils import calc_psnr_and_ssim,get_crop_im,get_clothes\n",
    "from model import Vgg19,MainNet\n",
    "\n",
    "from imageio import imread, imsave\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "from PIL import ImageFile\n",
    "\n",
    "# Let PIL decode image files whose data is cut short instead of raising\n",
    "# \"image file is truncated (N bytes not processed)\".\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "# Device handles: an explicit CPU device, plus the compute device, which is\n",
    "# the GPU when CUDA is available and the CPU otherwise.\n",
    "cpu = torch.device(\"cpu\")\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "source": [
    "# Load data"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the TTSR network on the selected device and wrap it in the project's Trainer.\n",
    "_model = TTSR.TTSR().to(device)\n",
    "t = Trainer(_model)\n",
    "\n",
    "# Saved tensor file; the loaded object must be a mapping, because its keys are\n",
    "# passed to ImageDataset together with the mapping itself.\n",
    "TENSOR_PATH = '/home/xincz/Documents/Tensors/tensor.pt'\n",
    "\n",
    "# map_location='cpu' keeps the load working on machines without CUDA (this notebook\n",
    "# already falls back to a CPU device) and keeps the tensors on the CPU so they can be\n",
    "# handed to the DataLoader worker processes.\n",
    "# NOTE(review): torch.load unpickles the file, so only point this at trusted data.\n",
    "data = torch.load(TENSOR_PATH, map_location='cpu')\n",
    "tar_ds = ImageDataset(list(data.keys()), data)\n",
    "\n",
    "# Fixed order (shuffle=False), batches of 64, loaded by 4 worker processes.\n",
    "tar_dl = DataLoader(tar_ds, shuffle=False, batch_size=64, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "-----lrid:A1905, -----rank: 1 ----, ------sim: 0.1633, -----elapsed: 10.39\n",
      "-----lrid:A1916, -----rank: 11 ----, ------sim: 0.3878, -----elapsed: 11.60\n",
      "-----lrid:A1923, -----rank: 1 ----, ------sim: 0.4100, -----elapsed: 12.20\n",
      "-----lrid:A1924, -----rank: 1 ----, ------sim: 0.5000, -----elapsed: 11.29\n",
      "-----lrid:A1928, -----rank: 8 ----, ------sim: 0.6667, -----elapsed: 12.00\n"
     ]
    },
    {
     "output_type": "error",
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-18de333f704b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mindir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/xincz/Documents/Datasets/models_540'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0ms_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msearch_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'models_540-1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtar_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0ms_time\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/img_search/trainer.py\u001b[0m in \u001b[0;36msearch_list\u001b[0;34m(self, indir, file_name, tar_dl)\u001b[0m\n\u001b[1;32m    112\u001b[0m                 \u001b[0;31m# Single query\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m                 \u001b[0mt1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m                 \u001b[0mtmp_res1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtmp1_sim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_scale_res\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtar_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlrid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m                 \u001b[0msearch_res\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlrid\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtmp_res1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtmp1_sim\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/img_search/trainer.py\u001b[0m in \u001b[0;36mget_scale_res\u001b[0;34m(tar_dl, lr1, lr2, lrid)\u001b[0m\n\u001b[1;32m    136\u001b[0m             \u001b[0mres2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mres2\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mget_maxrel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtar_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    137\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 138\u001b[0;31m         \u001b[0mres1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mres1_sim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_res\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlrid\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# get top-20\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    139\u001b[0m         \u001b[0mres2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mres2_sim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_res\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlrid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    140\u001b[0m         \u001b[0ma\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mres1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlrid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/img_search/trainer.py\u001b[0m in \u001b[0;36mget_res\u001b[0;34m(res, lrid)\u001b[0m\n\u001b[1;32m    145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    146\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_res\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlrid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m     \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreverse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# 按照某一列进行排序的方法\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    148\u001b[0m     \u001b[0mres_20\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    149\u001b[0m     \u001b[0mres_20_sim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0massigned\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mWRAPPER_ASSIGNMENTS\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m     \u001b[0;34m@\u001b[0m\u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0massigned\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0massigned\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moverrides\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mhas_torch_function\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Input directory handed to Trainer.search_list as `indir`.\n",
    "indir = '/home/xincz/Documents/Datasets/models_540'\n",
    "\n",
    "# Time the whole search with perf_counter: it is monotonic, so the elapsed value\n",
    "# cannot be skewed by system clock adjustments the way time.time() differences can.\n",
    "s_time = time.perf_counter()\n",
    "res = t.search_list(indir, 'models_540-1', tar_dl)\n",
    "print(time.perf_counter() - s_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}